"""Build .docx from paper.md + all output tables."""

import os
import re
from pathlib import Path

from docx import Document
from docx.shared import Inches, Pt, Cm, RGBColor
from docx.enum.text import WD_ALIGN_PARAGRAPH
from docx.enum.table import WD_TABLE_ALIGNMENT

# Project layout. The root defaults to the author's WSL checkout; set the
# SECTORAL_SAVINGS_PROJECT environment variable to build from elsewhere.
PROJECT = Path(os.environ.get(
    "SECTORAL_SAVINGS_PROJECT",
    "/mnt/c/demographics_capital_flows/sectoral_savings",
))
PAPER_MD = PROJECT / "paper" / "paper.md"
TABLES_DIR = PROJECT / "output" / "tables"
OUTPUT = PROJECT / "paper" / "sectoral_savings_paper.docx"

# (appendix heading, file name inside TABLES_DIR), in appendix order.
TABLE_FILES = [
    ("Table 1: Summary Statistics", "summary_statistics.md"),
    ("Table 2: Sectoral Decomposition", "decomposition.md"),
    ("Table 3: OECD vs non-OECD", "oecd_decomposition.md"),
    ("Table 4: Income Terciles", "income_terciles.md"),
    ("Table 5: Mediation Analysis", "mediation.md"),
    ("Table 6: Lagged Demographics", "robustness.md"),
    ("Table 7: Consumption Decomposition", "consumption_decomposition.md"),
    ("Table 8: Capital Openness Interactions", "kaopen_interactions.md"),
    ("Table 9: Eurozone Subsample", "eurozone.md"),
    ("Table 10: Structural Break", "structural_break.md"),
    ("Table 11: Aging Speed", "aging_speed.md"),
    ("Table 12: Cointegration Tests", "cointegration.md"),
    ("Table 13: Bootstrap Standard Errors", "bootstrap.md"),
    ("Table 14: Placebo Tests", "placebo.md"),
    ("Table 15: Leave-One-Out", "leave_one_out.md"),
    ("Table 16: Regional Jackknife", "regional_jackknife.md"),
    ("Table 17: Investment Decomposition", "investment_decomposition.md"),
    ("Table 18: Fiscal Interactions", "fiscal_interactions.md"),
    ("Table 19: Trade Interactions", "trade_interactions.md"),
    ("Table R1: Effect Size Scaling", "referee_effect_scaling.md"),
    ("Table R2: Winsorized Robustness", "referee_winsorized.md"),
    ("Table R3: S-I vs CA Reconciliation", "referee_si_ca_reconciliation.md"),
    ("Table R4: OECD Government Saving LOO", "referee_oecd_govt_loo.md"),
    ("Table R5: Z Correlations with Age Ratios", "referee_z_correlations.md"),
    ("Table R6: Two-Way FE Robustness", "referee_fe_robustness.md"),
    ("Table R7: Trade Margins", "referee_trade_margins.md"),
]

# Extracts the table id ("1", "R3", ...) from a "Table <id>: ..." label.
_TABLE_ID_RE = re.compile(r'Table\s+(\S+):')

# Table id -> (label, filename) for every TABLE_FILES entry whose label has a
# "Table <id>:" prefix. Built with a comprehension so no loop variables leak
# into module scope.
# NOTE(review): build_docx() does not use this map; kept for importers.
INLINE_TABLE_MAP = {
    match.group(1): entry
    for entry, match in ((e, _TABLE_ID_RE.match(e[0])) for e in TABLE_FILES)
    if match
}


def set_cell_text(cell, text, bold=False, size=Pt(9), font_name='Times New Roman'):
    """Overwrite *cell* with a single centred run containing *text*.

    The cell's existing content is cleared first. The new run gets the
    requested font family and size, and its bold flag is set to *bold*.
    """
    cell.text = ""
    paragraph = cell.paragraphs[0]
    paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
    run = paragraph.add_run(text)
    run.bold = bold
    font = run.font
    font.name = font_name
    font.size = size


def _split_md_row(line):
    """Split one stripped Markdown table row into its stripped cell texts.

    Only the empty fragments produced by the optional outer '|' delimiters are
    discarded. Empty cells *inside* the row (e.g. the blank top-left header
    cell of a regression table) are kept so every row stays column-aligned.
    """
    cells = [cell.strip() for cell in line.split('|')]
    if line.startswith('|'):
        cells = cells[1:]
    if line.endswith('|') and cells:
        cells = cells[:-1]
    return cells


def _is_separator_row(cells):
    """Return True for a header/body delimiter row such as ``|---|:--:|``."""
    return bool(cells) and all('-' in cell and set(cell) <= set('-: ') for cell in cells)


def add_md_table(doc, md_text, title=None):
    """Append a Markdown pipe table to *doc* as a centred Word table.

    Args:
        doc: the python-docx document being built.
        md_text: text holding the table. Lines without '|' and lines that
            start with '#' are ignored; delimiter rows (``|---|``) are skipped.
        title: optional caption, written as a bold 11pt paragraph above the
            table.

    Nothing is emitted when fewer than two data rows (header included) are
    found. Row 0 is rendered bold as the header. Rows shorter than the widest
    row leave their trailing cells empty. A blank paragraph follows the table.
    """
    rows = []
    for raw_line in md_text.strip().split('\n'):
        line = raw_line.strip()
        if not line or '|' not in line or line.startswith('#'):
            continue
        cells = _split_md_row(line)
        if not cells or _is_separator_row(cells):
            continue
        rows.append(cells)
    if len(rows) < 2:
        return

    if title:
        p = doc.add_paragraph()
        run = p.add_run(title)
        run.bold = True
        run.font.size = Pt(11)
        run.font.name = 'Times New Roman'

    n_cols = max(len(row) for row in rows)
    table = doc.add_table(rows=len(rows), cols=n_cols)
    table.style = 'Light Shading'
    table.alignment = WD_TABLE_ALIGNMENT.CENTER
    for i, row in enumerate(rows):
        for j, cell_text in enumerate(row):
            set_cell_text(table.cell(i, j), cell_text, bold=(i == 0))
    # Spacer paragraph so following content does not butt against the table.
    doc.add_paragraph()


def parse_md_file(filepath):
    """Split a Markdown table file into ordered 'table' and 'note' sections.

    Args:
        filepath: path of a UTF-8 Markdown file.

    Returns:
        A list, in file order, whose items are either
        ``('table', title, markdown)``, a run of consecutive lines that contain
        '|' (title is the text of the nearest preceding '#' heading not
        already used by an earlier table, else None), or
        ``('note', text)``, a stripped line starting with '*' outside a
        table. All other lines are ignored.
    """
    sections = []
    pending_title = None
    table_lines = []

    def flush_table():
        # Emit the buffered table, if any; each heading titles only one table.
        nonlocal pending_title
        if table_lines:
            sections.append(('table', pending_title, '\n'.join(table_lines)))
            table_lines.clear()
            pending_title = None

    for line in filepath.read_text(encoding='utf-8').split('\n'):
        if '|' in line and not line.startswith('#'):
            table_lines.append(line)
            continue
        # Any non-table line ends the current table. This way a note written
        # under a table is emitted after it instead of being dropped, and two
        # tables separated by a blank line stay separate.
        flush_table()
        if line.startswith('#'):
            pending_title = re.sub(r'^#+\s*', '', line).strip()
        elif line.strip().startswith('*'):
            sections.append(('note', line.strip()))
    flush_table()
    return sections


# Plain-text / Unicode renderings for LaTeX control words (name without the
# backslash). Control words not listed here are dropped while their braced
# argument survives, e.g. \hat{x} -> x, \text{GDP} -> GDP, \left( -> (.
_LATEX_SYMBOLS = {
    # Operator names stay as words.
    'log': 'log', 'ln': 'ln', 'exp': 'exp', 'max': 'max', 'min': 'min',
    'sup': 'sup', 'inf': 'inf', 'lim': 'lim', 'arg': 'arg', 'Pr': 'Pr',
    'sin': 'sin', 'cos': 'cos',
    # Greek letters.
    'alpha': '\u03b1', 'beta': '\u03b2', 'gamma': '\u03b3', 'delta': '\u03b4',
    'epsilon': '\u03b5', 'varepsilon': '\u03b5', 'zeta': '\u03b6',
    'eta': '\u03b7', 'theta': '\u03b8', 'iota': '\u03b9', 'kappa': '\u03ba',
    'lambda': '\u03bb', 'mu': '\u03bc', 'nu': '\u03bd', 'xi': '\u03be',
    'pi': '\u03c0', 'rho': '\u03c1', 'sigma': '\u03c3', 'tau': '\u03c4',
    'upsilon': '\u03c5', 'phi': '\u03c6', 'varphi': '\u03c6', 'chi': '\u03c7',
    'psi': '\u03c8', 'omega': '\u03c9',
    'Gamma': '\u0393', 'Delta': '\u0394', 'Theta': '\u0398', 'Lambda': '\u039b',
    'Xi': '\u039e', 'Pi': '\u03a0', 'Sigma': '\u03a3', 'Phi': '\u03a6',
    'Psi': '\u03a8', 'Omega': '\u03a9',
    # Operators, relations, arrows and spacing.
    'cdot': '\u00b7', 'cdots': '\u22ef', 'ldots': '\u2026', 'dots': '\u2026',
    'times': '\u00d7', 'div': '\u00f7', 'pm': '\u00b1', 'approx': '\u2248',
    'sim': '\u223c', 'equiv': '\u2261', 'neq': '\u2260', 'ne': '\u2260',
    'leq': '\u2264', 'le': '\u2264', 'geq': '\u2265', 'ge': '\u2265',
    'in': '\u2208', 'infty': '\u221e', 'partial': '\u2202', 'sum': '\u2211',
    'prod': '\u220f', 'int': '\u222b', 'sqrt': '\u221a', 'mid': '|',
    'rightarrow': '\u2192', 'to': '\u2192', 'leftarrow': '\u2190',
    'Rightarrow': '\u21d2', 'quad': ' ', 'qquad': ' ',
}

# \frac with two brace groups that contain no further braces (innermost).
_FRAC_RE = re.compile(r'\\frac\s*\{([^{}]*)\}\s*\{([^{}]*)\}')


def _frac_operand(operand):
    """Return a fraction operand, parenthesised unless it is a single token."""
    operand = operand.strip()
    if re.fullmatch(r'\\?[\w.^]+', operand):
        return operand
    return f'({operand})'


def sanitize_math(math_text):
    """Convert a LaTeX math fragment into readable plain text.

    The result is meant for an ordinary (italic) Word run, not an equation
    object. Braced sub/superscripts are flattened (``x_{it}`` -> ``x_it``),
    ``\\frac{a}{b}`` becomes ``a/b``, known control words are mapped through
    ``_LATEX_SYMBOLS`` and unknown ones are dropped, and leftover braces are
    removed.

    Control words are matched whole, so ``\\cdots`` is not misread as
    ``\\cdot`` followed by "s".
    """
    # Flatten braced sub/superscripts: x_{it} -> x_it, e^{rt} -> e^rt.
    math_text = re.sub(r'_\{([^}]+)\}', r'_\1', math_text)
    math_text = re.sub(r'\^\{([^}]+)\}', r'^\1', math_text)
    # \frac{a}{b} -> a/b; repeat so nested fractions collapse inside-out.
    while True:
        math_text, count = _FRAC_RE.subn(
            lambda m: f'{_frac_operand(m.group(1))}/{_frac_operand(m.group(2))}',
            math_text,
        )
        if not count:
            break
    # Line breaks and spacing commands become a space; \! (negative space)
    # simply disappears.
    math_text = math_text.replace('\\\\', ' ')
    math_text = re.sub(r'\\[,;: ]', ' ', math_text)
    math_text = math_text.replace('\\!', '')
    # Escaped literal characters: \% -> %, \& -> &, \# -> #, \$ -> $.
    math_text = re.sub(r'\\([%&#$])', r'\1', math_text)
    # Whole control words only; unknown words are dropped.
    math_text = re.sub(
        r'\\([A-Za-z]+)',
        lambda m: _LATEX_SYMBOLS.get(m.group(1), ''),
        math_text,
    )
    # Remaining grouping braces carry no meaning in plain text.
    return math_text.replace('{', '').replace('}', '')


# U+2042 (ASTERISM) temporarily stands in for a literal '*' so the emphasis
# parsing below cannot consume regression significance stars.
_STAR_PLACEHOLDER = '\u2042'

# A run of 1-3 stars opening a significance-legend entry such as
# "*** p<0.01". It must be preceded by start of text, whitespace, ',', ';' or
# '(' so that a closing emphasis star ("*Note:* p<0.05") is not mistaken
# for one.
_LEGEND_STARS_RE = re.compile(r'(?<![^\s,;(])\*{1,3}(?=\s*p\s*[<\u2264=])')

# Inline spans recognised in body text: **bold**, *italic*, $math$.
_INLINE_SPAN_RE = re.compile(r'(\*\*[^*]+\*\*|\*[^*]+\*|\$[^$]+\$)')


def _protect_sig_stars(text):
    """Swap significance stars that directly follow a digit (0.12***) for placeholders."""
    text = re.sub(r'(\d)\*\*\*(?=[\s,)|]|$)', r'\1' + '\u2042\u2042\u2042', text)
    text = re.sub(r'(\d)\*\*(?!\*)(?=[\s,)|]|$)', r'\1' + '\u2042\u2042', text)
    text = re.sub(r'(\d)\*(?!\*)(?=[\s,)|]|$)', r'\1' + '\u2042', text)
    return text


def _restore_sig_stars(text):
    """Turn every star placeholder back into a literal '*'."""
    return text.replace(_STAR_PLACEHOLDER, '*')


def _strip_note_emphasis(note):
    """Remove Markdown emphasis markers from a table note.

    Significance legends such as ``*** p<0.01, ** p<0.05`` are kept intact.
    The previous blanket strip turned them into `` p<0.01,  p<0.05``.
    """
    protected = _LEGEND_STARS_RE.sub(
        lambda m: _STAR_PLACEHOLDER * len(m.group(0)), note)
    plain = re.sub(r'\*+([^*]+)\*+', r'\1', protected)
    return _restore_sig_stars(plain)


def _add_run(paragraph, text, *, bold=False, italic=False,
             font_name='Times New Roman', size_pt=11):
    """Append a run with the given font and size; bold/italic are set only when requested."""
    run = paragraph.add_run(text)
    if bold:
        run.bold = True
    if italic:
        run.italic = True
    run.font.name = font_name
    run.font.size = Pt(size_pt)
    return run


def _add_heading(doc, text, level, size_pt, center=False):
    """Add a heading whose runs are black Times New Roman at *size_pt* points."""
    paragraph = doc.add_heading(text, level=level)
    if center:
        paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
    for run in paragraph.runs:
        run.font.size = Pt(size_pt)
        run.font.name = 'Times New Roman'
        run.font.color.rgb = RGBColor(0, 0, 0)
    return paragraph


def _add_title_block(doc, title):
    """Emit the centred title, the author line, the affiliation lines and a blank spacer."""
    _add_heading(doc, title, 0, 16, center=True)
    author = doc.add_paragraph()
    author.alignment = WD_ALIGN_PARAGRAPH.CENTER
    _add_run(author, 'Brian Peters', bold=True, size_pt=13)
    affiliation = doc.add_paragraph()
    affiliation.alignment = WD_ALIGN_PARAGRAPH.CENTER
    _add_run(affiliation, 'Independent Researcher\nMarch 2026\nWorking Paper')
    doc.add_paragraph()


def _read_display_math(lines, i):
    """Collect the $$...$$ block starting at lines[i].

    Returns ``(latex, next_i)``, where next_i is the index just past the
    block. A block that is never closed runs to the end of *lines*.
    """
    stripped = lines[i].strip()
    if stripped.endswith('$$') and len(stripped) > 4:
        return stripped[2:-2], i + 1
    parts = [stripped.replace('$$', '')]
    i += 1
    while i < len(lines) and '$$' not in lines[i]:
        parts.append(lines[i].strip())
        i += 1
    if i < len(lines):
        parts.append(lines[i].strip().replace('$$', ''))
        i += 1
    return ' '.join(part for part in parts if part), i


def _add_display_math(doc, latex):
    """Add a centred italic Cambria Math paragraph holding the sanitised *latex*."""
    paragraph = doc.add_paragraph()
    paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER
    _add_run(paragraph, sanitize_math(latex), italic=True,
             font_name='Cambria Math', size_pt=10)


def _add_body_paragraph(doc, line, center=False):
    """Render one Markdown body line (plain, '- ' bullet or 'N. ' item) with inline markup."""
    text = _protect_sig_stars(line.strip()).replace(' -- ', ' \u2014 ')
    number = re.match(r'^(\d+)\.\s+', text)
    if text.startswith('- '):
        paragraph = doc.add_paragraph(style='List Bullet')
        text = text[2:]
    elif number:
        paragraph = doc.add_paragraph(style='List Number')
        text = text[number.end():]
    else:
        paragraph = doc.add_paragraph()
    if center:
        paragraph.alignment = WD_ALIGN_PARAGRAPH.CENTER

    for part in _INLINE_SPAN_RE.split(text):
        if not part:
            continue  # re.split yields empty strings around matches
        if part.startswith('**') and part.endswith('**'):
            _add_run(paragraph, _restore_sig_stars(part[2:-2]), bold=True)
        elif part.startswith('*') and part.endswith('*') and len(part) > 2:
            _add_run(paragraph, _restore_sig_stars(part[1:-1]), italic=True)
        elif part.startswith('$') and part.endswith('$'):
            math = sanitize_math(_restore_sig_stars(part[1:-1]))
            _add_run(paragraph, math, italic=True,
                     font_name='Cambria Math', size_pt=10)
        else:
            _add_run(paragraph, _restore_sig_stars(part))


def _add_table_appendix(doc):
    """Append the 'Appendix: Tables' section with one page per existing TABLE_FILES entry."""
    doc.add_page_break()
    _add_heading(doc, 'Appendix: Tables', 0, 16, center=True)
    for label, filename in TABLE_FILES:
        filepath = TABLES_DIR / filename
        if not filepath.exists():
            # Best-effort: missing tables are skipped, but visibly.
            print(f"Skipping missing table file: {filepath}")
            continue
        doc.add_page_break()
        _add_heading(doc, label, 1, 13)
        for section in parse_md_file(filepath):
            if section[0] == 'table':
                _, title, md = section
                # Suppress a caption that merely repeats the appendix heading.
                add_md_table(doc, md, title=title if title != label else None)
            elif section[0] == 'note':
                _add_run(doc.add_paragraph(), _strip_note_emphasis(section[1]),
                         italic=True, size_pt=9)


def build_docx():
    """Render PAPER_MD plus every available table in TABLE_FILES into OUTPUT.

    The paper body is parsed line by line:
    - the first '# ' line becomes the title block;
    - '## ' and '### ' lines become level-1 and level-2 headings;
    - '$$' lines start display math;
    - lines starting with '|' start a table;
    - every other non-blank line is a paragraph.
    Paragraphs between the title and the first '## ' heading are centred.
    An appendix of tables follows, then the file is saved and its path and
    size are printed.
    """
    doc = Document()
    for section in doc.sections:
        section.top_margin = Cm(2.54)
        section.bottom_margin = Cm(2.54)
        section.left_margin = Cm(2.54)
        section.right_margin = Cm(2.54)
    normal = doc.styles['Normal']
    normal.font.name = 'Times New Roman'
    normal.font.size = Pt(11)
    normal.paragraph_format.space_after = Pt(6)
    normal.paragraph_format.line_spacing = 1.15

    lines = PAPER_MD.read_text(encoding='utf-8').split('\n')
    title_emitted = False
    in_title_block = False  # between the '# ' title and the first '## ' heading
    i = 0
    while i < len(lines):
        line = lines[i]
        stripped = line.strip()

        if line.startswith('# ') and not title_emitted:
            _add_title_block(doc, line[2:].strip())
            title_emitted = in_title_block = True
            i += 1
        elif line.startswith('## '):
            in_title_block = False
            _add_heading(doc, line[3:].strip(), 1, 13)
            i += 1
        elif line.startswith('### '):
            _add_heading(doc, line[4:].strip(), 2, 11)
            i += 1
        elif stripped.startswith('$$'):
            latex, i = _read_display_math(lines, i)
            _add_display_math(doc, latex)
        elif stripped.startswith('|'):
            start = i
            while i < len(lines) and '|' in lines[i]:
                i += 1
            add_md_table(doc, '\n'.join(lines[start:i]))
        else:
            if stripped:
                _add_body_paragraph(doc, line, center=in_title_block)
            i += 1

    _add_table_appendix(doc)

    doc.save(str(OUTPUT))
    print(f"Saved: {OUTPUT}")
    print(f"Size: {OUTPUT.stat().st_size / 1024:.0f} KB")


if __name__ == '__main__':
    build_docx()
